import time
from cdalib.packet import Packet, PacketType
from cdalib.commands import Command
from struct import pack, unpack
from serial import Serial, PARITY_NONE, STOPBITS_ONE, EIGHTBITS
from serial.serialutil import SerialException
from io import BufferedReader
import math
import os
import sys


class Channel:
    """Abstract byte-transport interface.

    Concrete transports override ``write``, ``read`` and ``write_read``;
    calling any of them on this base class raises RuntimeError.
    """

    @staticmethod
    def _not_implemented():
        # Shared failure path for every abstract method below.
        raise RuntimeError("Unimplemented")

    def write(self, data: bytes):
        """Send *data* over the transport (subclass responsibility)."""
        Channel._not_implemented()

    def read(self, size):
        """Receive *size* bytes from the transport (subclass responsibility)."""
        Channel._not_implemented()

    def write_read(self, data: bytes, size):
        """Send *data*, then receive *size* bytes (subclass responsibility)."""
        Channel._not_implemented()


class COMChannel(Channel):
    """Channel implementation backed by a pySerial ``Serial`` port.

    All reads are exact: a read that returns fewer bytes than requested
    raises RuntimeError instead of handing a truncated buffer to the caller.
    """

    def __init__(self, device: Serial):
        # The underlying port; only its ``write`` and ``read`` methods are used.
        self.device = device

    def write(self, data: bytes):
        """Write *data* to the serial port."""
        self.device.write(data)

    def read(self, size):
        """Read exactly *size* bytes.

        Raises:
            RuntimeError: the port returned fewer than *size* bytes.
        """
        data = self.device.read(size)
        if len(data) != size:
            raise RuntimeError('Partial read expected %d bytes got %d' % (size, len(data)))
        return data

    def write_read(self, data: bytes, size):
        """Write *data*, then read exactly *size* bytes.

        Routed through ``read`` so a short reply is reported the same way as
        in a plain ``read`` rather than being returned silently truncated.
        """
        self.write(data)
        return self.read(size)


class Device:
    """Host-side driver for the command/packet protocol spoken over a Channel.

    Commands are sent as ``PT_COMMAND`` packets. Replies arrive as
    ``PT_STATUS`` / ``PT_STATUS_DATA`` packets. ``PT_CONSOLE`` packets
    carrying device output may be interleaved with any reply and are echoed
    to stdout.
    """

    # NOTE(review): USB_VID / USB_PID are not referenced in this class;
    # presumably used elsewhere to locate the device.
    USB_VID = 0x18D1
    USB_PID = 0xDEED
    # Chunk size for image uploads (send_lk / send_atf / send_pl / boot_mtk_da).
    BLOCK_SIZE = 1024
    # Chunk size for partition writes; written data must be a multiple of it.
    MMC_BLOCK_SIZE = 512

    def connect_to_com(self, port, wait: bool):
        """Open *port* at 115200 baud, 8N1, no flow control, and wrap it in a COMChannel.

        With *wait* true, keep retrying until the port opens, then pause one
        second before returning. With *wait* false, make a single attempt.

        Raises:
            SerialException: *wait* is false and the port could not be opened.
        """
        while True:
            try:
                self.device = Serial(
                    port, 115200,
                    parity=PARITY_NONE,
                    stopbits=STOPBITS_ONE,
                    bytesize=EIGHTBITS,
                    xonxoff=False,
                    rtscts=False,
                    dsrdtr=False
                )
            except SerialException:
                if not wait:
                    # Fail here rather than leave self.ch unset and crash later
                    # with an unrelated AttributeError in setup_device().
                    raise
                # Port not available yet: back off briefly instead of spinning.
                time.sleep(0.1)
                continue
            self.ch = COMChannel(self.device)
            if wait:
                time.sleep(1)
            return

    def __init__(self, port, wait: bool):
        """Connect to *port* (if given), perform the setup handshake and enable the console.

        When *port* is None no serial port is opened. ``self.ch`` must then
        already exist (presumably provided by a subclass) before setup runs.
        """
        if port is not None:
            self.connect_to_com(port, wait)

        self.setup_device()
        self.execute_command(Command.CMD_ENABLE_CONSOLE)

    def _read_packet(self):
        """Read and return the next packet from ``self.ch``."""
        packet = Packet(None)
        packet.read(self.ch)
        return packet

    def setup_device(self):
        """Send ``PT_SETUP`` and require a ``PT_STATUS`` packet in reply.

        Raises:
            RuntimeError: any other packet type is received.
        """
        Packet(PacketType.PT_SETUP).write(self.ch)
        packet = self._read_packet()
        if packet.type != PacketType.PT_STATUS:
            raise RuntimeError(
                'protocol error, unexpected packet {}'.format(packet.type))

    def execute_command(self, command: Command, args: bytes = b'', status=True):
        """Send *command* (its value as a little-endian u32) followed by *args*.

        When *status* is true, wait for the reply while echoing console packets.

        Returns:
            The payload of a ``PT_STATUS_DATA`` reply, or None for a plain
            ``PT_STATUS`` reply (or when *status* is false).

        Raises:
            RuntimeError: an unexpected packet type arrives, or the reply
                carries a non-zero status.
        """
        data = pack("<I", command.value) + args
        Packet(PacketType.PT_COMMAND, data=data).write(self.ch)
        if not status:
            return None

        while True:
            packet = self._read_packet()
            if self.process_packet(packet):
                continue

            if packet.type not in (PacketType.PT_STATUS, PacketType.PT_STATUS_DATA):
                raise RuntimeError(
                    'protocol error, unexpected packet type {}'.format(packet.type))

            if packet.status != 0:
                raise RuntimeError(
                    'command execution failed: status={:x}'.format(packet.status))

            if packet.type == PacketType.PT_STATUS_DATA:
                return packet.data
            return None

    def _wait_for_status(self):
        """Echo console packets until a ``PT_STATUS`` packet arrives.

        Other packet types are ignored.

        Raises:
            RuntimeError: the status packet reports a non-zero status.
        """
        while True:
            packet = self._read_packet()
            if self.process_packet(packet):
                continue
            if packet.type == PacketType.PT_STATUS:
                if packet.status != 0:
                    raise RuntimeError(
                        'command failed: {}'.format(packet.status))
                return

    def reboot(self):
        """Ask the device to reboot, deliberately ignoring any error from the exchange.

        Presumably the device may reset before its status reply arrives, so
        failures are swallowed. Only ``Exception`` is caught, so Ctrl-C and
        SystemExit still propagate.
        """
        try:
            self.execute_command(Command.CMD_REBOOT)
        except Exception:
            pass

    def attach_console(self):
        """Enable console forwarding and echo console output forever (never returns).

        Non-console packets are ignored.
        """
        self.execute_command(Command.CMD_ENABLE_CONSOLE)
        while True:
            # process_packet echoes PT_CONSOLE packets and ignores everything else.
            self.process_packet(self._read_packet())

    def process_packet(self, packet: Packet):
        """Echo *packet* to stdout if it is a console packet.

        Output that cannot be decoded as UTF-8, or cannot be encoded for the
        current stdout, is written as raw bytes instead.

        Returns:
            True if *packet* was a ``PT_CONSOLE`` packet (now consumed),
            False otherwise.
        """
        if packet.type != PacketType.PT_CONSOLE:
            return False
        try:
            print(packet.data.decode('utf-8'), end='', flush=True)
        except UnicodeError:
            # Covers both decode failures (invalid UTF-8) and encode failures
            # (stdout encoding cannot represent the text).
            sys.stdout.buffer.write(packet.data)
            sys.stdout.buffer.flush()
        return True

    def efuse(self, verbose: bool):
        """Run CMD_PRINT_EFUSE_CFG and echo device output until a status packet arrives.

        *verbose* is accepted for interface compatibility but is unused.
        """
        self.execute_command(Command.CMD_PRINT_EFUSE_CFG)
        self._wait_for_status()

    def dump_bootinfo(self):
        """Run CMD_PRINT_BOOTINFO and echo device output until a status packet arrives."""
        self.execute_command(Command.CMD_PRINT_BOOTINFO)
        self._wait_for_status()

    def send_payload(self, payload, block_size, num_blocks):
        """Stream *payload* as PT_DATA packets of at most *block_size* bytes, then wait for status.

        *num_blocks* is accepted for interface compatibility. The chunk count
        is derived from ``len(payload)``.

        Raises:
            RuntimeError: the closing status packet is non-zero.
        """
        # Offset-based slicing avoids re-copying the remaining payload on every chunk.
        for start in range(0, len(payload), block_size):
            Packet(PacketType.PT_DATA,
                   data=payload[start:start + block_size]).write(self.ch)
        self._wait_for_status()

    @staticmethod
    def _load_image(p, caller):
        """Return the bytes of *p*, zero-padded to a multiple of BLOCK_SIZE.

        *p* may be a file path (str), an open binary file (BufferedReader), or
        a bytes/bytearray object. Anything else raises ValueError naming *caller*.
        """
        if isinstance(p, str):
            with open(p, 'rb') as file:
                payload = file.read()
        elif isinstance(p, BufferedReader):
            payload = p.read()
        elif isinstance(p, (bytes, bytearray)):
            # Copy so padding never mutates a caller-owned bytearray.
            payload = bytes(p)
        else:
            raise ValueError('invalid argument to {}'.format(caller))

        remainder = len(payload) % Device.BLOCK_SIZE
        if remainder:
            payload += b'\x00' * (Device.BLOCK_SIZE - remainder)
        return payload

    def _upload_image(self, command, payload, extra_args=b''):
        """Announce *payload* via *command*, then stream it.

        The command args are block size and block count (u32 LE each),
        followed by *extra_args*. *payload* must already be block-aligned.
        """
        num_blocks = len(payload) // Device.BLOCK_SIZE
        self.execute_command(command, args=pack(
            "<II", Device.BLOCK_SIZE, num_blocks) + extra_args)
        self.send_payload(payload, Device.BLOCK_SIZE, num_blocks)

    def send_lk(self, p):
        """Upload an LK image (path, open binary file, or bytes) with CMD_SEND_LK."""
        payload = self._load_image(p, 'send_lk')
        print('sending LK ...', flush=True)
        self._upload_image(Command.CMD_SEND_LK, payload)

    def boot_lk(self):
        """Issue CMD_BOOT_LK."""
        self.execute_command(Command.CMD_BOOT_LK)

    def send_atf(self, p):
        """Upload an ATF image (path, open binary file, or bytes) with CMD_SEND_ATF."""
        payload = self._load_image(p, 'send_atf')
        print('sending ATF ...', flush=True)
        self._upload_image(Command.CMD_SEND_ATF, payload)

    def boot_mtk_da(self, p, use_lk: bool):
        """Upload an image with CMD_SEND_MTK_DA, then issue CMD_BOOT_MTK_DA with *use_lk* as u32."""
        payload = self._load_image(p, 'boot_mtk_da')
        # NOTE(review): the length sent is the padded length (a multiple of
        # BLOCK_SIZE), as before; confirm whether the device expects the
        # unpadded image size here.
        self._upload_image(Command.CMD_SEND_MTK_DA, payload,
                           pack("<I", len(payload)))
        self.execute_command(Command.CMD_BOOT_MTK_DA, args=pack("<I", use_lk))

    def get_logs(self):
        """Run CMD_GET_LOGS and echo device output until a status packet arrives."""
        self.execute_command(Command.CMD_GET_LOGS)
        self._wait_for_status()

    def mmc_init(self, n: int):
        """Issue CMD_MMC_INIT with *n* as a little-endian u32."""
        self.execute_command(Command.CMD_MMC_INIT, args=pack('<I', n))

    def secpol_dump(self):
        """Issue CMD_SECPOL_DUMP."""
        self.execute_command(Command.CMD_SECPOL_DUMP)

    def secpol_lock(self):
        """Issue CMD_SECPOL_LOCK."""
        self.execute_command(Command.CMD_SECPOL_LOCK)

    def secpol_unlock(self):
        """Issue CMD_SECPOL_UNLOCK."""
        self.execute_command(Command.CMD_SECPOL_UNLOCK)

    def partition_read(self, partition_name, output):
        """Read partition *partition_name* into the file *output*.

        The command is sent with offset 0, length 0xffffffff and mmc field 0.
        The device replies with the byte count (u64 LE) and then streams
        PT_DATA packets until that many bytes have arrived. A trailing status
        packet is then read, and a non-zero status is printed.

        Raises:
            ValueError: *partition_name* is not a str.
            RuntimeError: the size reply is missing or malformed, or a failure
                status arrives mid-transfer.
        """
        if not isinstance(partition_name, str):
            raise ValueError('partition_name must be a str')
        name = partition_name.encode('utf-8')

        offset = pack('<Q', 0)
        # NOTE(review): presumably means "up to the end of the partition".
        length = pack('<Q', 0xffffffff)
        mmc = pack('<I', 0)

        # Opened before the command so an unwritable path fails before the
        # device starts streaming; ``with`` guarantees the file is closed.
        with open(output, 'wb') as file:
            response = self.execute_command(Command.CMD_READ_P,
                                            args=offset + length + mmc + name)
            if response is None or len(response) != 8:
                raise RuntimeError('protocol error')

            left = unpack('<Q', response)[0]
            while left > 0:
                packet = self._read_packet()
                if self.process_packet(packet):
                    continue
                if packet.type == PacketType.PT_DATA:
                    file.write(packet.data)
                    left -= len(packet.data)
                elif (packet.type in (PacketType.PT_STATUS, PacketType.PT_STATUS_DATA)
                        and packet.status != 0):
                    # Without this the loop would wait forever for data that
                    # the device has stopped sending.
                    raise RuntimeError(
                        'transfer failed: status={:x}'.format(packet.status))
                else:
                    print('unexpected packet type {}'.format(packet.type))

        packet = self._read_packet()
        if (packet.type in (PacketType.PT_STATUS, PacketType.PT_STATUS_DATA)
                and packet.status != 0):
            print('status = %d' % packet.status, flush=True)

    def partition_write(self, partition_name, input):
        """Write *input* (a file path or bytes) to partition *partition_name*.

        The command is sent with offset 0 and mmc field 0. Data is sent in
        MMC_BLOCK_SIZE chunks as PT_DATA packets, with a progress line on stdout.

        Raises:
            ValueError: *partition_name* is not a str, or *input* is neither
                str nor bytes.
            RuntimeError: the data size is not a multiple of MMC_BLOCK_SIZE.
        """
        if not isinstance(partition_name, str):
            raise ValueError('partition_name must be a str')
        name = partition_name.encode('utf-8')

        if isinstance(input, str):
            is_file = True
            data_size = os.path.getsize(input)
        elif isinstance(input, bytes):
            is_file = False
            data_size = len(input)
        else:
            raise ValueError('input must be a file path (str) or bytes')

        # Validate before opening anything so a bad size leaks no file handle.
        if data_size % Device.MMC_BLOCK_SIZE != 0:
            raise RuntimeError('unimplemented')

        offset = pack('<Q', 0)
        length = pack('<Q', data_size)
        mmc = pack('<I', 0)

        file = open(input, 'rb') if is_file else None
        try:
            self.execute_command(Command.CMD_WRITE_P,
                                 args=offset + length + mmc + name)
            left = data_size
            pos = 0
            while left > 0:
                if file is not None:
                    chunk = file.read(Device.MMC_BLOCK_SIZE)
                else:
                    chunk = input[pos:pos + Device.MMC_BLOCK_SIZE]
                    pos += Device.MMC_BLOCK_SIZE
                Packet(PacketType.PT_DATA, data=chunk).write(self.ch)
                left -= Device.MMC_BLOCK_SIZE
                print('%d bytes left     ' % left, flush=True, end='\r')
        finally:
            if file is not None:
                file.close()
        # NOTE(review): no closing status packet is read here, unlike
        # partition_read; confirm whether the device sends one after the data.

    def run_test(self, test_id: int):
        """Issue CMD_RUN_TEST with *test_id* as a little-endian u32."""
        self.execute_command(Command.CMD_RUN_TEST, args=pack('<I', test_id))

    def send_pl(self, p):
        """Upload a PL image (path, open binary file, or bytes) with CMD_SEND_PL."""
        payload = self._load_image(p, 'send_pl')
        self._upload_image(Command.CMD_SEND_PL, payload)

    def boot_pl(self):
        """Issue CMD_BOOT_PL."""
        self.execute_command(Command.CMD_BOOT_PL)

    def memory_read(self, address, block_size, num_blocks):
        """Read *num_blocks* PT_DATA packets after CMD_READ_MEMORY and return their concatenation.

        Console packets are echoed and do not count toward *num_blocks*.

        Raises:
            RuntimeError: a failure status arrives before all blocks.
        """
        # NOTE(review): the original format '<I' could not pack three values
        # (it always raised struct.error). '<III' packs each field as a u32;
        # confirm against the device side, since an address above 4 GiB
        # would need '<QII'.
        self.execute_command(Command.CMD_READ_MEMORY, args=pack(
            '<III', address, block_size, num_blocks))

        chunks = []
        received = 0
        while received < num_blocks:
            packet = self._read_packet()
            if self.process_packet(packet):
                continue
            if packet.type == PacketType.PT_DATA:
                chunks.append(packet.data)
                received += 1
            elif (packet.type in (PacketType.PT_STATUS, PacketType.PT_STATUS_DATA)
                    and packet.status != 0):
                raise RuntimeError(
                    'memory read failed: status={:x}'.format(packet.status))
        return b''.join(chunks)

    def enter_bootrom(self):
        """Issue CMD_ENTER_BOOTROM."""
        self.execute_command(Command.CMD_ENTER_BOOTROM)
